Example: Statistical inference for ODE-based infectious disease models

Introduction

What are we going to do in this vignette

In this vignette, we'll demonstrate how to use EpiAware in conjunction with the SciML ecosystem for Bayesian inference of infectious disease dynamics. The model and data are heavily based on Contemporary statistical inference for infectious disease models using Stan by Chatzilena et al. (2019).

We'll cover the following key points:

  1. Defining the deterministic ODE model from Chatzilena et al section 2.2.2 using SciML ODE functionality and an EpiAware observation model.

  2. Build on this to define the stochastic ODE model from Chatzilena et al section 2.2.3 using an EpiAware observation model.

  3. Fitting the deterministic ODE model to data from an Influenza outbreak in an English boarding school.

  4. Fitting the stochastic ODE model to data from an Influenza outbreak in an English boarding school.

What might I need to know before starting

This vignette builds on concepts from EpiAware observation models; familiarity with the SciML and Turing ecosystems would be useful but is not essential.

Packages used in this vignette

Alongside the EpiAware package, we will use the OrdinaryDiffEq and SciMLSensitivity packages to interface with the SciML ecosystem; this is a lower-dependency usage of DifferentialEquations.jl that exposes, respectively, ODE solvers and adjoint methods for ODE solves (that is, methods for propagating parameter derivatives through functions containing ODE solutions). Bayesian inference will be done with NUTS from the Turing ecosystem. We will also use the CairoMakie package for plotting and DataFramesMeta for data manipulation.

using EpiAware
using Turing
using OrdinaryDiffEq, SciMLSensitivity #ODE solvers and adjoint methods
using Distributions, Statistics #Statistics packages
using CSV, DataFramesMeta #Data wrangling
using CairoMakie, PairPlots
using ReverseDiff #Automatic differentiation backend
begin #Date utility and set Random seed
    using Dates
    using Random
    Random.seed!(1234)
end
TaskLocalRNG()

Single population SIR model

As mentioned in Chatzilena et al, disease spread is frequently modelled in terms of ODE-based models. The study population is divided into compartments, each representing a specific stage of the epidemic: in this case, susceptible, infected, and recovered individuals.

$$\begin{aligned} {dS \over dt} &= - \beta \frac{I(t)}{N} S(t) \\ {dI \over dt} &= \beta \frac{I(t)}{N} S(t) - \gamma I(t) \\ {dR \over dt} &= \gamma I(t). \\ \end{aligned}$$

where S(t) represents the number of susceptible, I(t) the number of infected and R(t) the number of recovered individuals at time t. The total population size is denoted by N (with N = S(t) + I(t) + R(t)), β denotes the transmission rate and γ denotes the recovery rate.
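As a quick consistency check, summing the three equations shows that the total population is conserved:

$${dS \over dt} + {dI \over dt} + {dR \over dt} = -\beta \frac{I(t)}{N} S(t) + \beta \frac{I(t)}{N} S(t) - \gamma I(t) + \gamma I(t) = 0,$$

so \(N = S(t) + I(t) + R(t)\) is constant over time.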

We can interface to the SciML ecosystem by writing a function with the signature:

(du, u, p, t) -> nothing

Where:

  • du is the vector field of the ODE problem, e.g. \({dS \over dt}\), \({dI \over dt}\) etc. This is calculated in-place (commonly denoted using ! in function names in Julia).

  • u is the state of the ODE problem, e.g. \(S\), \(I\), etc.

  • p is an object that represents the parameters of the ODE problem, e.g. \(\beta\), \(\gamma\).

  • t is the time of the ODE problem.

We do this for the SIR model described above in a function called sir!:

function sir!(du, u, p, t)
    S, I, R = u
    β, γ = p
    du[1] = -β * I * S
    du[2] = β * I * S - γ * I
    du[3] = γ * I

    return nothing
end
sir! (generic function with 1 method)
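As a quick standalone sanity check (with illustrative parameter values and proportion-scale initial conditions, not the values fitted below), we can solve this vector field directly:

```julia
using OrdinaryDiffEq

# Illustrative-only parameters: β = 2.0, γ = 0.5
p_demo = [2.0, 0.5]
u0_demo = [0.99, 0.01, 0.0]   # S, I, R as proportions of N
demo_prob = ODEProblem(sir!, u0_demo, (0.0, 14.0), p_demo)
demo_sol = solve(demo_prob, Tsit5(); saveat = 1.0)
# demo_sol[2, :] gives the infected proportion at each saved day;
# it should rise to a peak and then decline as susceptibles deplete.
```

Note that sir! as written assumes proportion-scale states, since du[1] = -β * I * S omits the 1/N factor; this matches the rescaled initial conditions used in the inference models below.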

We combine the vector field function sir! with an initial condition u0 and the integration period tspan to make an ODEProblem. We do not define the parameters here; these will be defined within the inference approach.

sir_prob = ODEProblem(
    sir!,
    N .* [0.99, 0.01, 0.0],
    (0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1)
)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 14.0)
u0: 3-element Vector{Float64}:
 755.37
   7.63
   0.0

Note that this is analogous to the EpiProblem approach we expose from EpiAware, as used in the Mishra et al replication. The difference is that here we are going to use ODE solvers from the SciML ecosystem to generate the dynamics of the underlying infections. In the linked example, we use latent process generation exposed by EpiAware as the underlying generative process for underlying dynamics.

Data for inference

There was a brief but intense outbreak of Influenza within the (semi-)closed community of a boarding school, reported to the British Medical Journal in 1978. The outbreak lasted from 22nd January to 4th February; it is reported that one infected child started the epidemic, which then spread rapidly. Of the 763 children at the boarding school, 512 became ill.

We downloaded the data of this outbreak using the R package outbreaks, which is maintained as part of the R Epidemics Consortium (RECON).

data = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/446-add-chatzilena-et-al-as-a-replication-example/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" |>
       url -> CSV.read(download(url), DataFrame) |>
              df -> @transform(df,
    :ts=(:date .- minimum(:date)) .|> d -> d.value + 1.0,)
Column1  date        in_bed  convalescent  ts
1        1978-01-22  3       0             1.0
2        1978-01-23  8       0             2.0
3        1978-01-24  26      0             3.0
4        1978-01-25  76      0             4.0
5        1978-01-26  225     9             5.0
6        1978-01-27  298     17            6.0
7        1978-01-28  258     105           7.0
8        1978-01-29  233     162           8.0
9        1978-01-30  189     176           9.0
10       1978-01-31  128     166           10.0
11       1978-02-01  68      150           11.0
12       1978-02-02  29      85            12.0
13       1978-02-03  14      47            13.0
14       1978-02-04  4       20            14.0
N = 763;

Inference for the deterministic SIR model

The boarding school data gives the number of children "in bed" and "convalescent" on each of 14 days from 22nd Jan to 4th Feb 1978. We follow Chatzilena et al and treat the number "in bed" as a proxy for the number of children in the infectious (I) compartment in the ODE model.

The full observation model is:

$$\begin{aligned} Y_t &\sim \text{Poisson}(\lambda_t)\\ \lambda_t &= I(t)\\ \beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ \gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 500)\\ S(0) /N &\sim \text{Beta}(0.5, 0.5). \end{aligned}$$

NB: Chatzilena et al give \(\lambda_t = \int_0^t \beta \frac{I(s)}{N} S(s) - \gamma I(s)ds = I(t) - I(0).\) However, this doesn't match their underlying Stan code.

From EpiAware, we have the PoissonError struct which defines the probabilistic structure of this observation error model.

obs = PoissonError()
PoissonError()

Now we can write the observation model using the Turing PPL.

@model function deterministic_ode_mdl(Yt, ts, obs, prob, N;
        solver = AutoTsit5(Rosenbrock23()),
        upjitter = 1e-3
)
    ##Priors##
    β ~ LogNormal(0.0, 1.0)
    γ ~ Gamma(0.004, 1 / 0.002)
    S₀ ~ Beta(0.5, 0.5)

    ##remake ODE model##
    _prob = remake(prob;
        u0 = [S₀, 1 - S₀, 0.0],
        p = [β, γ]
    )

    ##Solve remade ODE model##

    sol = solve(_prob, solver;
        saveat = ts,
        verbose = false)

    ##log-like accumulation using obs##
    λt = N * sol[2, :] .+ upjitter #expected It
    @submodel obsYt = generate_observations(obs, Yt, λt)

    ##Generated quantities##
    return (; sol, obsYt, R0 = β / γ)
end
deterministic_ode_mdl (generic function with 2 methods)

We instantiate the model in two ways:

  1. deterministic_mdl: This conditions the generative model on the data observation. We can sample from this model to find the posterior distribution of the parameters.

  2. deterministic_uncond_mdl: This doesn't condition on the data. This is useful for prior and posterior predictive modelling.

deterministic_mdl = deterministic_ode_mdl(data.in_bed, data.ts, obs, sir_prob, N);
deterministic_uncond_mdl = deterministic_ode_mdl(
    fill(missing, length(data.in_bed)), data.ts, obs, sir_prob, N);

We add a useful plotting utility.

function plot_predYt(data, gens; title::String, ylabel::String)
    fig = Figure()
    ga = fig[1, 1:2] = GridLayout()

    ax = Axis(ga[1, 1];
        title = title,
        xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string),
        ylabel = ylabel
    )
    pred_Yt = mapreduce(hcat, gens) do gen
        gen.obsYt
    end |> X -> mapreduce(vcat, eachrow(X)) do row
        quantile(row, [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75])'
    end

    lines!(ax, data.ts, pred_Yt[:, 1]; linewidth = 3, color = :green, label = "Median")
    band!(
        ax, data.ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.2), label = "95% CI")
    band!(
        ax, data.ts, pred_Yt[:, 4], pred_Yt[:, 5], color = (:green, 0.4), label = "80% CI")
    band!(
        ax, data.ts, pred_Yt[:, 6], pred_Yt[:, 7], color = (:green, 0.6), label = "50% CI")
    scatter!(ax, data.in_bed, label = "data")
    leg = Legend(ga[1, 2], ax; framevisible = false)
    hidespines!(ax)

    fig
end
plot_predYt (generic function with 1 method)

Prior predictive sampling

let
    prior_chn = sample(deterministic_uncond_mdl, Prior(), 2000)
    gens = generated_quantities(deterministic_uncond_mdl, prior_chn)
    plot_predYt(data, gens;
        title = "Prior predictive: deterministic model",
        ylabel = "Number of Infected students"
    )
end

The prior predictive checking suggests that a priori our parameter beliefs are very far from the data. Approaching the inference naively can lead to poor fits.

We do three things to mitigate this:

  1. We choose a switching ODE solver which switches between explicit (Tsit5) and implicit (Rosenbrock23) solvers. This helps avoid the ODE solver failing when the sampler tries extreme parameter values. This is the default solver = AutoTsit5(Rosenbrock23()) above.

  2. To avoid the effect of small numerically negative values of λt, we add a small positive upjitter.

  3. We locate the maximum likelihood point, that is we ignore the influence of the priors, as a useful starting point for NUTS.
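For intuition on point 2 (an illustrative sketch, separate from the model code above): Distributions.jl rejects a non-positive Poisson rate, so even a tiny negative λt produced by ODE solver error would break the likelihood evaluation.

```julia
using Distributions

λ_bad = -1e-10          # tiny negative value, e.g. from solver error
# Poisson(λ_bad) throws an ArgumentError: the rate must be non-negative.
λ_ok = λ_bad + 1e-3     # adding the small positive upjitter keeps it valid
logpdf(Poisson(λ_ok), 3)  # evaluates without error
```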

nmle_tries = 100
100
mle_fit = map(1:nmle_tries) do _
    fit = try
        maximum_likelihood(deterministic_mdl)
    catch
        (lp = -Inf,)
    end
end |>
          fits -> (findmax(fit -> fit.lp, fits)[2], fits) |>
                  max_and_fits -> max_and_fits[2][max_and_fits[1]]
ModeResult with maximized lp of -67.55
[1.892700300823729, 0.4807992676989487, 0.9995207039580596]
mle_fit.optim_result.retcode
ReturnCode.Success = 1

Note that we choose the best out of 100 tries for the MLE estimators.

Now, we sample aiming at 1000 samples for each of 4 chains.

chn = sample(
    deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4;
    initial_params = fill(mle_fit.values.array, 4))
iteration  chain  β        γ         S₀        lp        n_steps  is_accept  ...
501        1      1.76995  0.480405  0.999279  -83.9361  5.0      1.0
502        1      1.81541  0.483976  0.999268  -80.5455  7.0      1.0
503        1      1.86949  0.469015  0.999464  -79.7799  15.0     1.0
504        1      1.8865   0.477198  0.999495  -79.2871  7.0      1.0
505        1      1.88458  0.491759  0.999502  -79.6707  7.0      1.0
506        1      1.91717  0.506963  0.999514  -82.5013  7.0      1.0
507        1      1.91033  0.500047  0.999577  -81.1282  7.0      1.0
508        1      1.91393  0.485322  0.999533  -79.7136  15.0     1.0
509        1      1.89074  0.491448  0.99947   -80.1383  15.0     1.0
510        1      1.88761  0.470936  0.999456  -80.7375  7.0      1.0
...
describe(chn)
2-element Vector{ChainDataFrame}:
 Summary Statistics (3 x 8)
 Quantiles (3 x 6)
pairplot(chn)

Posterior predictive plotting

let
    gens = generated_quantities(deterministic_uncond_mdl, chn)
    plot_predYt(data, gens;
        title = "Fitted deterministic model",
        ylabel = "Number of Infected students"
    )
end

Inference for the Stochastic SIR model

Chatzilena et al also present an auto-regressive model for connecting the outcome of the ODE model to illness observations. The argument is that the stochastic component of the model can absorb noise generated by possible mis-specification of the model.

In their approach they consider \(\kappa_t = \log \lambda_t\) where \(\kappa_t\) evolves according to an Ornstein-Uhlenbeck process:

$$d\kappa_t = \phi(\mu_t - \kappa_t) dt + \sigma dB_t.$$

Which has transition density:

$$\kappa_{t+1} | \kappa_t \sim N\Big(\mu_t + \left(\kappa_t - \mu_t\right)e^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big).$$

Where \(\mu_t = \log(I(t))\).

We modify this approach since it implies that the \(\mu_t\) is treated as constant between observation times.

Instead we redefine \(\kappa_t\) as the log-residual:

$$\kappa_t = \log(\lambda_t / I(t)).$$

With the transition density:

$$\kappa_{t+1} | \kappa_t \sim N\Big(\kappa_te^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big).$$

This is an AR(1) process.
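Writing the transition density in standard AR(1) form makes the mapping to the parameterisation used below explicit: the damping coefficient is \(e^{-\phi}\) and the innovation variance is the bracketed term,

$$\kappa_{t+1} = e^{-\phi} \kappa_t + \epsilon_t, \qquad \epsilon_t \sim N\left(0, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi}\right)\right).$$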

The full stochastic model is:

$$\begin{aligned} Y_t &\sim \text{Poisson}(\lambda_t)\\ \lambda_t &= I(t)\exp(\kappa_t)\\ \beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ \gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 500)\\ S(0) /N &\sim \text{Beta}(0.5, 0.5)\\ \phi & \sim \text{HalfNormal}(0, 100) \\ 1 / \sigma^2 & \sim \text{InvGamma}(0.1,0.1). \end{aligned}$$

We will use the AR struct from EpiAware to define the auto-regressive process in this model, which has a direct parameterisation of the AR model.

To convert from the formulation above, we sample from the priors and define HalfNormal priors based on the sampled prior means of \(e^{-\phi}\) and \({\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\). We also add a strong prior that \(\kappa_1 \approx 0\).

ϕs = rand(truncated(Normal(0, 100), lower = 0.0), 1000)
1000-element Vector{Float64}:
 173.84613119927994
  29.518553653808322
 234.29914391751802
  38.19535251952747
  66.62365073973159
  72.73303789047586
  42.49825458703761
   ⋮
  74.05686152554611
  30.257250194026454
  52.248951190112535
 282.12537789602675
  86.34379224015433
  34.05920866168606
σ²s = rand(InverseGamma(0.1, 0.1), 1000) .|> x -> 1 / x
1000-element Vector{Float64}:
 0.1706591427622955
 0.13297364123999544
 6.101966577965764e-7
 0.2893819681848476
 1.322381087725992e-5
 6.608159436905257e-14
 2.2022146256780077
 ⋮
 0.021735075270736363
 0.007669216970214308
 1.6657774849296185e-14
 9.85208774805459
 3.789360268704109e-5
 0.12976058255234915
sampled_AR_damps = ϕs .|> ϕ -> exp(-ϕ)
1000-element Vector{Float64}:
 3.1592538246247405e-76
 1.5144503307107885e-13
 1.7586308283732263e-102
 2.5820768645648703e-17
 1.1633653365494597e-29
 2.5848955096216586e-32
 3.493353558480045e-19
 ⋮
 6.878815994119324e-33
 7.235077760799627e-14
 2.0350206143612827e-23
 2.9819830798768062e-123
 3.1722505656900094e-38
 1.6153759778859835e-15
sampled_AR_stds = map(ϕs, σ²s) do ϕ, σ²
    (1 - exp(-2 * ϕ)) * σ² / (2 * ϕ)
end
1000-element Vector{Float64}:
 0.0004908338816199159
 0.002252373927251004
 1.3021743220952362e-9
 0.0037881829737911226
 9.924261677673111e-8
 4.542749504603446e-16
 0.025909471425088895
 ⋮
 0.000146745857325042
 0.00012673354189549592
 1.5940774379073527e-16
 0.017460477716551674
 2.1943443589809463e-7
 0.0019049265624647299

We define the AR(1) process by matching the means of HalfNormal prior distributions for the damp parameter and the standard deviation parameter to the prior means calculated from the Chatzilena et al definition.

ar = AR(
    damp_priors = [HalfNormal(mean(sampled_AR_damps))],
    std_prior = HalfNormal(mean(sampled_AR_stds)),
    init_priors = [Normal(0, 0.001)]
)
AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, Int64}(Distributions.Product{Distributions.Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(HalfNormal{Float64}(μ=0.0070030651982823525), 1)), HalfNormal{Float64}(μ=0.012836842862849867), DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}(m=[0.0], σ=0.001), 1)
@model function stochastic_ode_mdl(Yt, ts, logobsprob, obs, prob, N;
        solver = AutoTsit5(Rosenbrock23()),
        upjitter = 1e-2,
)

    ##Priors##
    β ~ LogNormal(0.0, 1.0)
    γ ~ Gamma(0.004, 1 / 0.002)
    S₀ ~ Beta(0.5, 0.5)

    ##Remake ODE model##
    _prob = remake(prob;
        u0 = [S₀, 1 - S₀, 0.0],
        p = [β, γ]
    )

    ##Solve ODE model##
    sol = solve(_prob, solver;
        saveat = ts,
        verbose = false
    )

    ##Sample the log-residual AR process##
    nobs = length(Yt)
    @submodel κₜ = generate_latent(logobsprob, nobs)
    λt = @. N * sol[2, :] * exp(κₜ) + upjitter

    ##log-like accumulation using obs##
    @submodel obsYt = generate_observations(obs, Yt, λt)

    ##Generated quantities##
    return (; sol, obsYt, R0 = β / γ)
end
stochastic_ode_mdl (generic function with 2 methods)
stochastic_mdl = stochastic_ode_mdl(
    data.in_bed,
    data.ts,
    ar,
    obs,
    sir_prob,
    N
)
DynamicPPL.Model{typeof(stochastic_ode_mdl), (:Yt, :ts, :logobsprob, :obs, :prob, :N), (:solver, :upjitter), (), Tuple{Vector{Int64}, Vector{Float64}, AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, Int64}, PoissonError, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(sir!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Int64}, Tuple{CompositeAlgorithm{0, Tuple{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, AutoSwitch{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}, Rational{Int64}, Int64}}, Float64}, DynamicPPL.DefaultContext}(Main.var"workspace#17".stochastic_ode_mdl, (Yt = [3, 8, 26, 76, 225, 298, 258, 233, 189, 128, 68, 29, 14, 4], ts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0], logobsprob = AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, Int64}(Distributions.Product{Distributions.Continuous, 
HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(HalfNormal{Float64}(μ=0.0070030651982823525), 1)), HalfNormal{Float64}(μ=0.012836842862849867), DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}(m=[0.0], σ=0.001), 1), obs = PoissonError(), prob = ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(sir!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}(ODEFunction{true, SciMLBase.AutoSpecialize, typeof(sir!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}(Main.var"workspace#17".sir!, LinearAlgebra.UniformScaling{Bool}(true), nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED, nothing, nothing, nothing, nothing), [755.37, 7.63, 0.0], (0.0, 14.0), SciMLBase.NullParameters(), Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}(), SciMLBase.StandardODEProblem()), N = 763), (solver = CompositeAlgorithm{0, Tuple{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, AutoSwitch{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, 
typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}, Rational{Int64}, Int64}}((Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}(OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!, static(false)), Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}(nothing, OrdinaryDiffEqCore.DEFAULT_PRECS, OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!)), AutoSwitch{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}, Rational{Int64}, Int64}(Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}(OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!, static(false)), Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}(nothing, OrdinaryDiffEqCore.DEFAULT_PRECS, OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!), 10, 3, 9//10, 9//10, 2, false, 5)), upjitter = 0.01), DynamicPPL.DefaultContext())
stochastic_uncond_mdl = stochastic_ode_mdl(
    fill(missing, length(data.in_bed)),
    data.ts,
    ar,
    obs,
    sir_prob,
    N
)
DynamicPPL.Model{typeof(stochastic_ode_mdl), (:Yt, :ts, :logobsprob, :obs, :prob, :N), (:solver, :upjitter), (), Tuple{Vector{Missing}, Vector{Float64}, AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, Int64}, PoissonError, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(sir!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Int64}, Tuple{CompositeAlgorithm{0, Tuple{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, AutoSwitch{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}, Rational{Int64}, Int64}}, Float64}, DynamicPPL.DefaultContext}(Main.var"workspace#17".stochastic_ode_mdl, (Yt = [missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing], ts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0], logobsprob = AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, 
Int64}(Distributions.Product{Distributions.Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(HalfNormal{Float64}(μ=0.0070030651982823525), 1)), HalfNormal{Float64}(μ=0.012836842862849867), DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}(m=[0.0], σ=0.001), 1), obs = PoissonError(), prob = ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(sir!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}(ODEFunction{true, SciMLBase.AutoSpecialize, typeof(sir!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}(Main.var"workspace#17".sir!, LinearAlgebra.UniformScaling{Bool}(true), nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED, nothing, nothing, nothing, nothing), [755.37, 7.63, 0.0], (0.0, 14.0), SciMLBase.NullParameters(), Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}(), SciMLBase.StandardODEProblem()), N = 763), (solver = CompositeAlgorithm{0, Tuple{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, AutoSwitch{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), 
Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}, Rational{Int64}, Int64}}((Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}(OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!, static(false)), Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}(nothing, OrdinaryDiffEqCore.DEFAULT_PRECS, OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!)), AutoSwitch{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}, Rational{Int64}, Int64}(Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}(OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!, static(false)), Rosenbrock23{0, true, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}, true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}(nothing, OrdinaryDiffEqCore.DEFAULT_PRECS, OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!), 10, 3, 9//10, 9//10, 2, false, 5)), upjitter = 0.01), DynamicPPL.DefaultContext())

Prior predictive checking

let
    prior_chn = sample(stochastic_uncond_mdl, Prior(), 2000)
    gens = generated_quantities(stochastic_uncond_mdl, prior_chn)
    plot_predYt(data, gens;
        title = "Prior predictive: stochastic model",
        ylabel = "Number of Infected students"
    )
end

The prior predictive checking again shows misaligned prior beliefs; for example, a priori (without data) we would not expect the median prediction for the number of ill children to be about 600 out of 763 after just one day.

The latent process for the log-residuals \(\kappa_t\) is only weakly identified by the priors alone, so we look for a reasonable MAP point to start NUTS from. We do this by first making an initial guess which is a mixture of:

  1. The posterior averages from the deterministic model.

  2. The prior averages of the structure parameters of the AR(1) process.

  3. Zero for the time-varying noise underlying the AR(1) process.

rand(stochastic_mdl)
(β = 0.5738913978051624, γ = 1.3926672379545036e-65, S₀ = 0.09950960157424277, σ_AR = 0.00710185114139463, ar_init = [-0.00010805696366646121], damp_AR = [0.014752282958082038], ϵ_t = [0.36894057222592974, 0.015879367010213655, -1.2021084547505942, -0.4867263304639022, 0.1945625569382327, -0.45269634579778484, 1.5312241119430532, 0.3506430786260485, -0.45462794302641213, 0.5962344946907625, 0.12086654378623474, 0.6460856902803586, 1.0089899514829812])
initial_guess = [[mean(chn[:β]),
                     mean(chn[:γ]),
                     mean(chn[:S₀]),
                     mean(ar.std_prior),
                     mean(ar.init_prior)[1],
                     mean(ar.damp_prior)[1]]
                 zeros(13)]
19-element Vector{Float64}:
 1.8897056893371456
 0.48108056736368593
 0.9994964534413393
 0.012836842862849867
 0.0
 0.0070030651982823525
 0.0
 ⋮
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

Starting from the initial guess, the MAP point is calculated rapidly in one pass.

map_fit_stoch_mdl = maximum_a_posteriori(stochastic_mdl;
    adtype = AutoReverseDiff(),
    initial_params = initial_guess
)
ModeResult with maximized lp of -71.73
[1.8991973007454837, 0.48135108041055585, 0.9995349042045181, 0.018558849362902724, 1.5181088512512426e-6, 4.353491564466609e-5, 0.03732809978668891, 0.043694117781807006, -0.12386508274105254, 0.30798281159738977, -0.032824750386364496, -0.3877656318973242, 0.3360214839886304, 0.7077349748167989, 0.503343338152674, 0.03903889934050348, -0.2553007182666317, -0.25492752360961835, -0.25859899980438394]

Now we can run NUTS, sampling 1000 posterior draws per chain for 4 chains.

chn2 = sample(
    stochastic_mdl,
    NUTS(; adtype = AutoReverseDiff(true)),
    MCMCThreads(), 1000, 4;
    initial_params = fill(map_fit_stoch_mdl.values.array, 4)
)
iteration  chain  β        γ         S₀        σ_AR        ar_init[1]    damp_AR[1]  ...
501        1      1.94963  0.482194  0.999613  0.0171322   0.0005996     0.025503
502        1      1.93367  0.484801  0.999599  0.0468069   -0.000816213  0.00159983
503        1      1.85573  0.478385  0.999421  0.0381151   -0.00080567   0.00696753
504        1      1.94677  0.464615  0.999626  0.00722003  0.000249854   0.0104259
505        1      1.93456  0.484329  0.999609  0.0260551   -0.000246263  0.00109421
506        1      1.91912  0.480601  0.999557  0.0292174   0.000753464   0.0155966
507        1      1.88157  0.490176  0.999517  0.0191062   -0.000919145  0.00354336
508        1      1.86476  0.480362  0.999349  0.0328774   -0.00107855   0.00456089
509        1      1.98588  0.478472  0.999638  0.00817061  -0.000312205  0.00419099
510        1      1.89293  0.494282  0.999429  0.0243211   0.000403061   0.00380944
...
describe(chn2)
2-element Vector{ChainDataFrame}:
 Summary Statistics (19 x 8)
 Quantiles (19 x 6)
pairplot(chn2[[:β, :γ, :S₀, :σ_AR, Symbol("ar_init[1]"), Symbol("damp_AR[1]")]])
let
    vars = mapreduce(vcat, 1:13) do i
        Symbol("ϵ_t[$i]")
    end
    pairplot(chn2[vars])
end
let
    gens = generated_quantities(stochastic_uncond_mdl, chn2)
    plot_predYt(data, gens;
        title = "Fitted stochastic model",
        ylabel = "Number of Infected students"
    )
end